{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f61ac45f",
   "metadata": {},
   "source": [
    "# SFT_GRPO_4: Reward hacking"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45271196",
   "metadata": {},
   "source": [
    "Start by importing dependencies and setting up two clients, one for OpenAI and one for Predibase:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43e4f9a0-51a4-4faa-8ca0-40ec6f58c483",
   "metadata": {
    "height": 268
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from datasets import load_dataset\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "\n",
    "from utils import *\n",
    "\n",
    "load_dotenv()\n",
    "\n",
    "client = OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n",
    "\n",
    "pb_client = OpenAI(\n",
    "    base_url=os.environ[\"PREDIBASE_MODEL_LLAMA_URL\"],\n",
    "    api_key=os.environ[\"PREDIBASE_API_KEY\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50e0e055",
   "metadata": {},
   "source": [
    "## Hacking the summarization task with longer summaries\n",
    "\n",
    "Here, you'll see how longer summaries could lead to higher rewards. Start by loading the same earnings call dataset from the previous lesson:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4bbc31a-3226-48aa-86f7-b8464d6f1eda",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "ds = load_dataset(\"mrSoul7766/ECTSum\")\n",
    "transcript = ds[\"train\"][1][\"text\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b20217a8",
   "metadata": {},
   "source": [
    "Generate a quiz based on the call:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f1dfee7-c3b7-42ba-80f7-cd064c9d8e66",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "quiz = generate_quiz(transcript)\n",
    "print(quiz)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39f57f0d",
   "metadata": {},
   "source": [
    "Generate 8 summaries of the call (again, you'll use the Llama-3.1-8B-Instruct-dequantized, which is defined in the utils.py file):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a468c80b-46be-4156-98a8-de14ea9bc9c3",
   "metadata": {
    "height": 319
   },
   "outputs": [],
   "source": [
    "prompt = f\"\"\"Generate a concise bulleted summary of the \n",
    "information in the following earnings call transcript.\n",
    "\n",
    "Only respond with the summary, do not include any extraneous text.\n",
    "\n",
    "Transcript:\n",
    "\n",
    "{transcript}\n",
    "\"\"\"\n",
    "\n",
    "completions = pb_client.chat.completions.create(\n",
    "    model=MODEL_NAME,\n",
    "    messages=[\n",
    "        {\"role\": \"user\", \"content\": prompt},\n",
    "    ],\n",
    "    n=8,\n",
    "    temperature=0.9,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97d48620",
   "metadata": {},
   "source": [
    "Use each summary to take the quiz and get a reward score:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f3d3f9f-b7aa-4f00-a5ff-4878c9d03fab",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "responses = [choice.message.content for choice in completions.choices]\n",
    "quiz_rewards = [quiz_reward(resp, quiz) for resp in responses]\n",
    "quiz_rewards"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dbc08eeb",
   "metadata": {},
   "source": [
    "The transcript should get a perfect score: check that it does:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "837342ef-4fa7-4fc0-98d2-410fe61c35a5",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "transcript_score = quiz_reward(transcript, quiz)\n",
    "transcript_score"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ce2e3be",
   "metadata": {},
   "source": [
    "Check lengths of the 8 summaries and compare to full transcript:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c27c608a-fd8a-4427-9f4f-494025932be8",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "lengths = [len(resp) for resp in responses]\n",
    "lengths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7880ef55-f4b4-48cd-bb6d-a91ef2e5e680",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "len(transcript)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31e1ec4e",
   "metadata": {},
   "source": [
    "## Create a penalty function to discourage longer summaries\n",
    "\n",
    "Here, you'll create a reward function that assigns a negative score (i.e. a penalty) to the model for longer summaries. Over time during training, this penalty should discourage the model from getting higher quiz scores just by writing longer summaries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b05f378b-c2d1-47bf-b424-c62ce5b9a485",
   "metadata": {
    "height": 183
   },
   "outputs": [],
   "source": [
    "def length_penalty_reward(response: str) -> float:\n",
    "    length = len(response)\n",
    "    target_length = 1024\n",
    "    if length <= target_length:\n",
    "        return 0.0\n",
    "    else:\n",
    "        return max(\n",
    "            (target_length - length) / target_length,\n",
    "            -10\n",
    "        ) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a1c99ab-8ec2-457d-9278-1f4e76090318",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "transcript_reward = length_penalty_reward(transcript)\n",
    "transcript_reward"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c499bc3b",
   "metadata": {},
   "source": [
    "Show the length penalties and resulting advantages for the 8 summaries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5f5b87c-70ba-4abc-9efc-9fd27bd32ea2",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "lengths = [len(resp) for resp in responses]\n",
    "length_rewards = [\n",
    "    length_penalty_reward(resp) for resp in responses\n",
    "]\n",
    "print_length_table(lengths, length_rewards)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "676e73f6",
   "metadata": {},
   "source": [
    "Add the length penalty to the quiz reward score:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f96cccf1-0049-45dd-8e32-0a7edf616d98",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "def total_reward(length_reward, quiz_reward):\n",
    "    return length_reward + quiz_reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "584cb7d6-d561-4691-b4f7-4fbdb798d7e2",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "total_rewards = [\n",
    "    total_reward(length_reward, quiz_reward) \n",
    "    for length_reward, quiz_reward\n",
    "    in zip(length_rewards, quiz_rewards)\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9ffe100",
   "metadata": {},
   "source": [
    "Visualize the trade-off between length and quiz score in determining advantages:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a97364d5-5da7-4b20-b18a-c0986b554e5a",
   "metadata": {
    "height": 283
   },
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "advantages = compute_advantages(total_rewards)\n",
    "min_adv = min(advantages)\n",
    "max_adv = max(advantages)\n",
    "\n",
    "plt.figure(figsize=(10,6), facecolor='black')\n",
    "plt.style.use('dark_background')\n",
    "scatter = plt.scatter(length_rewards, quiz_rewards, c=advantages, cmap='RdYlGn', s=100, edgecolor='white', vmin=min_adv, vmax=max_adv)\n",
    "plt.colorbar(scatter, label='Advantage')\n",
    "plt.xlabel('Length Reward')\n",
    "plt.ylabel('Quiz Reward')\n",
    "plt.title('Length Reward vs Quiz Reward (colored by advantage)')\n",
    "plt.grid(True, alpha=0.2)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
